import os
import unittest
import ray
from transformers import AutoTokenizer
import torch
import tempfile
from xtuner.v1.ray.config.worker import RolloutConfig
from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers
from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, ReplayBufferConfig
from xtuner.v1.ray.environment import SingleTurnEnvironment
from xtuner.v1.datasets import RLTokenizeFnConfig
from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
from xtuner.v1.ray.rollout.controller import RolloutController
from xtuner.v1.utils.rl_test_utils import MockTimeoutRolloutWorker, MockRequestErrorRolloutWorker, MockClientErrorRolloutWorker, MockServerErrorRolloutWorker

# Required inputs are injected via environment variables; a missing variable
# fails fast at import time with a KeyError.
MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] 
TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"]
# Maps torch's accelerator type string to the Ray resource label used for scheduling.
resource_map = {"npu": "NPU", "cuda": "GPU"}
@ray.remote
class MockTimeoutRolloutController(RolloutController):
    """Rollout controller whose workers always simulate request timeouts.

    Worker deactivation is disabled so the failing workers stay in rotation
    and every retry keeps hitting the timeout path.
    """

    def deactivate_worker_by_url(self, url):
        # Deliberate no-op: keep the faulty worker active for the test.
        pass

    def _get_worker_cls(self):
        # Substitute the timeout-simulating mock for the real rollout worker.
        return ray.remote(MockTimeoutRolloutWorker)
@ray.remote
class MockRequestErrorRolloutController(RolloutController):
    """Rollout controller whose workers always raise request errors.

    Worker deactivation is disabled so the failing workers stay in rotation
    and every retry keeps hitting the request-error path.
    """

    def deactivate_worker_by_url(self, url):
        # Deliberate no-op: keep the faulty worker active for the test.
        pass

    def _get_worker_cls(self):
        # Substitute the request-error mock for the real rollout worker.
        return ray.remote(MockRequestErrorRolloutWorker)
@ray.remote
class MockClientErrorRolloutController(RolloutController):
    """Rollout controller whose workers always return client (4xx) errors.

    Worker deactivation is disabled so the failing workers stay in rotation
    and every retry keeps hitting the client-error path.
    """

    def deactivate_worker_by_url(self, url):
        # Deliberate no-op: keep the faulty worker active for the test.
        pass

    def _get_worker_cls(self):
        # Substitute the client-error mock for the real rollout worker.
        return ray.remote(MockClientErrorRolloutWorker)
@ray.remote
class MockServerErrorRolloutController(RolloutController):
    """Rollout controller whose workers always return server (5xx) errors.

    Worker deactivation is disabled so the failing workers stay in rotation
    and every retry keeps hitting the server-error path.
    """

    def deactivate_worker_by_url(self, url):
        # Deliberate no-op: keep the faulty worker active for the test.
        pass

    def _get_worker_cls(self):
        # Substitute the server-error mock for the real rollout worker.
        return ray.remote(MockServerErrorRolloutWorker)

class TestMockRollout(unittest.TestCase):
    """Run DataFlow against rollout controllers whose workers always fail
    (timeout / request / client / server errors) and verify that no rollout
    completes and the replay buffer records nothing.
    """

    # Shared guard: every test needs the lmdeploy rollout backend enabled.
    _skip_without_lmdeploy = unittest.skipIf(
        os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0",
        "lmdeploy backend is not enabled",
    )

    @classmethod
    def setUpClass(cls):
        # The rollout stack is exercised with FlashAttention-3 enabled.
        os.environ["XTUNER_USE_FA3"] = "1"

    @classmethod
    def tearDownClass(cls):
        del os.environ["XTUNER_USE_FA3"]

    def setUp(self):
        ray.init(num_cpus=80, ignore_reinit_error=True)

        # Sizing knobs shared by the dataflow and rollout configs below.
        self.global_batch_size = 3
        self.max_prompt_length = 4096
        self.max_response_length = 128
        self.max_concurrent = 3
        self.max_retry_times = 3

        # Per-test scratch area for worker logs; cleaned up in tearDown.
        self.temp_dir = tempfile.TemporaryDirectory()
        self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs")

        accelerator_kind = resource_map[torch.accelerator.current_accelerator().type]
        self.resources_cfg = AcceleratorResourcesConfig(
            accelerator=accelerator_kind,
            num_workers=8,
            num_cpus_per_worker=8,
            cpu_memory_per_worker=16 * 1024**3,  # 16 GB
        )
        self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg)

        self.rollout_cfg = RolloutConfig(
            env="test_mock_rollout",
            model_path=MODEL_PATH,
            model_name=os.path.basename(MODEL_PATH).lower(),
            tokenizer_path=MODEL_PATH,
            tensor_parallel_size=1,
            context_length=self.max_prompt_length + self.max_response_length,
            max_retry_per_worker=2,
            worker_log_dir=self.worker_log_dir,
        )

        self.dataflow_cfg = DataFlowConfig(
            max_concurrent=self.max_concurrent,
            global_batch_size=self.global_batch_size,
            max_retry_times=self.max_retry_times,
            worker_log_dir=self.worker_log_dir,
        )

        dataset_cfg = [
            {
                "dataset": DatasetConfig(name="mock_data", anno_path=TRAIN_DATA_PATH),
                "tokenize_fn": RLTokenizeFnConfig(max_length=self.max_prompt_length),
            }
        ]
        self.replay_buffer_cfg = ReplayBufferConfig(
            dataset_cfg=dataset_cfg,
            dataloader_cfg=DataloaderConfig(
                collator='fake_collator',
                pack_level='none',
                group_by_length=False,
            ),
            tokenizer=AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True),
            worker_log_dir=self.worker_log_dir,
        )

    def tearDown(self):
        ray.shutdown()
        self.temp_dir.cleanup()

    def _run_mock_test(self, mock_controller_cls, error_name: str):
        """Drive one dataflow run through *mock_controller_cls* and assert
        that the all-failing workers produce zero completed rollouts."""
        controller = mock_controller_cls.remote(self.rollout_cfg, self.pg)
        self.test_env = SingleTurnEnvironment.remote(
            "env", self.pg, self.rollout_cfg, rollout_controller=controller
        )
        self.test_dataflow = DataFlow.remote(
            "dataflow", self.dataflow_cfg, self.replay_buffer_cfg, self.test_env
        )

        completed_rollouts = ray.get(self.test_dataflow.run.remote(num=3))
        status = ray.get(self.test_dataflow.get_replaybuffer_status.remote())
        print(f"[{error_name}] Completed rollouts: {completed_rollouts}, Status: {status}")

        self.assertEqual(len(completed_rollouts[0]), 0, f"[{error_name}] Expected no rollouts to complete successfully.")
        self.assertEqual(status["rollout_finished_count"], 0, f"[{error_name}] Completed count in buffer should be 0.")
        self.assertEqual(status["rollout_paused_count"], 0, f"[{error_name}] Expected no rollouts to be interrupted.")
        ray.get(self.test_env.shutdown.remote())

    @_skip_without_lmdeploy
    def test_rollout_with_timeout_mock(self):
        self._run_mock_test(MockTimeoutRolloutController, "timeout")

    @_skip_without_lmdeploy
    def test_rollout_with_request_error_mock(self):
        self._run_mock_test(MockRequestErrorRolloutController, "request error")

    @_skip_without_lmdeploy
    def test_rollout_with_client_error_mock(self):
        self._run_mock_test(MockClientErrorRolloutController, "client error")

    @_skip_without_lmdeploy
    def test_rollout_with_server_error_mock(self):
        self._run_mock_test(MockServerErrorRolloutController, "server error")

# Allow running this test module directly, outside a pytest/unittest runner.
if __name__ == "__main__":
    unittest.main()